{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0cdfe535",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# Attention Scoring Functions\n",
    ":label:`sec_attention-scoring-functions`\n",
    "\n",
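    "In :numref:`sec_nadaraya-watson`, we used a Gaussian kernel to model the interactions between queries and keys.\n",
    "The exponent of the Gaussian kernel in :eqref:`eq_nadaraya-watson-gaussian`\n",
    "can be treated as an *attention scoring function* (attention scoring function),\n",
    "or *scoring function* (scoring function) for short,\n",
    "whose outputs are then fed into a softmax operation.\n",
    "This yields a probability distribution over the values that are paired with the keys (i.e., the attention weights).\n",
    "Finally, the output of attention pooling is simply a weighted sum of the values based on these attention weights.\n",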
    "\n",
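    "At a high level, the algorithm above can be used to implement\n",
    "the framework of attention mechanisms in :numref:`fig_qkv`.\n",
    ":numref:`fig_attention_output` illustrates\n",
    "how the output of attention pooling is computed as a weighted sum of the values,\n",
    "where $a$ denotes the attention scoring function.\n",
    "Since the attention weights form a probability distribution,\n",
    "the weighted sum is essentially a weighted average.\n",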
    "\n",
    "![Computing the output of attention pooling as a weighted sum of values](../img/attention-output.svg)\n",
    ":label:`fig_attention_output`\n",
    "\n",
    "Mathematically, suppose that we have\n",
    "a query $\\mathbf{q} \\in \\mathbb{R}^q$ and\n",
    "$m$ key-value pairs\n",
    "$(\\mathbf{k}_1, \\mathbf{v}_1), \\ldots, (\\mathbf{k}_m, \\mathbf{v}_m)$,\n",
    "where $\\mathbf{k}_i \\in \\mathbb{R}^k$ and $\\mathbf{v}_i \\in \\mathbb{R}^v$.\n",
    "The attention pooling $f$ is then instantiated as a weighted sum of the values:\n",
    "\n",
    "$$f(\\mathbf{q}, (\\mathbf{k}_1, \\mathbf{v}_1), \\ldots, (\\mathbf{k}_m, \\mathbf{v}_m)) = \\sum_{i=1}^m \\alpha(\\mathbf{q}, \\mathbf{k}_i) \\mathbf{v}_i \\in \\mathbb{R}^v,$$\n",
    ":eqlabel:`eq_attn-pooling`\n",
    "\n",
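    "where the attention weight (a scalar) of the query $\\mathbf{q}$ and key $\\mathbf{k}_i$\n",
    "is obtained by applying the softmax operation to the output of\n",
    "an attention scoring function $a$ that maps two vectors to a scalar:\n",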
    "\n",
    "$$\\alpha(\\mathbf{q}, \\mathbf{k}_i) = \\mathrm{softmax}(a(\\mathbf{q}, \\mathbf{k}_i)) = \\frac{\\exp(a(\\mathbf{q}, \\mathbf{k}_i))}{\\sum_{j=1}^m \\exp(a(\\mathbf{q}, \\mathbf{k}_j))} \\in \\mathbb{R}.$$\n",
    ":eqlabel:`eq_attn-scoring-alpha`\n",
    "\n",
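    "As we can see, different choices of the attention scoring function $a$ lead to different behaviors of attention pooling.\n",
    "In this section, we introduce two popular scoring functions, which we will later use to develop more sophisticated attention mechanisms.\n"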
   ]
  },
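  {
   "cell_type": "markdown",
   "id": "sketch-attn-pooling",
   "metadata": {},
   "source": [
    "To make :eqref:`eq_attn-pooling` and :eqref:`eq_attn-scoring-alpha` concrete,\n",
    "the following minimal sketch pools one query over $m=3$ key-value pairs.\n",
    "It assumes a plain dot product as a placeholder for the scoring function $a$\n",
    "(the actual scoring functions are introduced below);\n",
    "`score_fn` and `attention_pool` are hypothetical helper names, not part of `d2l`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def score_fn(q, k):\n",
    "    # Placeholder scoring function a(q, k): a plain dot product (assumption)\n",
    "    return torch.dot(q, k)\n",
    "\n",
    "def attention_pool(q, keys, values):\n",
    "    # Scores a(q, k_i) for every key, shape: (m,)\n",
    "    scores = torch.stack([score_fn(q, k) for k in keys])\n",
    "    # Attention weights alpha(q, k_i) via softmax over the m keys, shape: (m,)\n",
    "    alpha = torch.softmax(scores, dim=0)\n",
    "    # Weighted sum of the values, shape: (v,)\n",
    "    return alpha, (alpha.unsqueeze(-1) * values).sum(dim=0)\n",
    "\n",
    "q = torch.tensor([1.0, 0.0])\n",
    "keys = torch.tensor([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])\n",
    "values = torch.tensor([[10.0], [20.0], [30.0]])\n",
    "alpha, out = attention_pool(q, keys, values)\n",
    "print(alpha, alpha.sum())  # weights are nonnegative and sum to 1\n",
    "print(out)  # a weighted average of 10, 20, and 30\n",
    "```\n",
    "\n",
    "The key most aligned with the query receives the largest weight, so the output is pulled toward its value.\n"
   ]
  },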
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "77bfe8f0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:58:57.046810Z",
     "iopub.status.busy": "2023-08-18T06:58:57.045981Z",
     "iopub.status.idle": "2023-08-18T06:59:00.174680Z",
     "shell.execute_reply": "2023-08-18T06:59:00.173514Z"
    },
    "origin_pos": 2,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "937acec4",
   "metadata": {
    "origin_pos": 5
   },
   "source": [
    "## [**Masked Softmax Operation**]\n",
    "\n",
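    "As just mentioned, a softmax operation is used to output a probability distribution as attention weights.\n",
    "In some cases, not all the values should be fed into attention pooling.\n",
    "For instance, for efficient minibatch processing in :numref:`sec_machine_translation`,\n",
    "some text sequences are padded with special tokens that carry no meaning.\n",
    "To obtain an attention pooling over only the meaningful tokens as values,\n",
    "we can specify a valid sequence length (in number of tokens)\n",
    "so that positions beyond this range are filtered out when computing the softmax.\n",
    "The following `masked_softmax` function\n",
    "implements such a *masked softmax operation* (masked softmax operation),\n",
    "where every position beyond the valid length is masked and receives an attention weight of zero.\n"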
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3be54330",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:59:00.180834Z",
     "iopub.status.busy": "2023-08-18T06:59:00.179926Z",
     "iopub.status.idle": "2023-08-18T06:59:00.189306Z",
     "shell.execute_reply": "2023-08-18T06:59:00.188185Z"
    },
    "origin_pos": 7,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "def masked_softmax(X, valid_lens):\n",
    "    \"\"\"Perform softmax operation by masking elements on the last axis.\"\"\"\n",
    "    # X: 3D tensor, valid_lens: 1D or 2D tensor\n",
    "    if valid_lens is None:\n",
    "        return nn.functional.softmax(X, dim=-1)\n",
    "    else:\n",
    "        shape = X.shape\n",
    "        if valid_lens.dim() == 1:\n",
    "            valid_lens = torch.repeat_interleave(valid_lens, shape[1])\n",
    "        else:\n",
    "            valid_lens = valid_lens.reshape(-1)\n",
    "        # Masked elements on the last axis are replaced with a very large\n",
    "        # negative value, so that their softmax outputs are 0\n",
    "        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,\n",
    "                              value=-1e6)\n",
    "        return nn.functional.softmax(X.reshape(shape), dim=-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca8ddc66",
   "metadata": {
    "origin_pos": 10
   },
   "source": [
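    "To [**illustrate how this function works**],\n",
    "consider a minibatch of two $2 \\times 4$ matrix examples,\n",
    "where the valid lengths for these two examples are $2$ and $3$, respectively.\n",
    "As a result of the masked softmax operation, all entries beyond the valid lengths are masked to zero.\n"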
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5f3b0b7f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:59:00.194266Z",
     "iopub.status.busy": "2023-08-18T06:59:00.193661Z",
     "iopub.status.idle": "2023-08-18T06:59:00.238157Z",
     "shell.execute_reply": "2023-08-18T06:59:00.237124Z"
    },
    "origin_pos": 12,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.5980, 0.4020, 0.0000, 0.0000],\n",
       "         [0.5548, 0.4452, 0.0000, 0.0000]],\n",
       "\n",
       "        [[0.3716, 0.3926, 0.2358, 0.0000],\n",
       "         [0.3455, 0.3337, 0.3208, 0.0000]]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de7be477",
   "metadata": {
    "origin_pos": 15
   },
   "source": [
    "Similarly, we can also use a two-dimensional tensor to specify valid lengths for every row in each matrix example.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0296eee3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:59:00.245983Z",
     "iopub.status.busy": "2023-08-18T06:59:00.244094Z",
     "iopub.status.idle": "2023-08-18T06:59:00.257237Z",
     "shell.execute_reply": "2023-08-18T06:59:00.256163Z"
    },
    "origin_pos": 17,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.4125, 0.3273, 0.2602, 0.0000]],\n",
       "\n",
       "        [[0.5254, 0.4746, 0.0000, 0.0000],\n",
       "         [0.3117, 0.2130, 0.1801, 0.2952]]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))"
   ]
  },
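  {
   "cell_type": "markdown",
   "id": "sketch-masked-softmax",
   "metadata": {},
   "source": [
    "`masked_softmax` relies on `d2l.sequence_mask`, defined in an earlier section.\n",
    "As a self-contained sketch of the same idea (not the `d2l` implementation),\n",
    "the mask can also be built by broadcasting a comparison between key positions and valid lengths;\n",
    "`masked_softmax_sketch` is a hypothetical name used only here.\n",
    "Because $\\exp(-10^6)$ underflows to exactly $0$ in `float32`, the masked positions get zero weight.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "def masked_softmax_sketch(X, valid_lens):\n",
    "    # X: (batch_size, no. of queries, no. of keys); valid_lens: 1D or 2D tensor\n",
    "    if valid_lens is None:\n",
    "        return nn.functional.softmax(X, dim=-1)\n",
    "    if valid_lens.dim() == 1:\n",
    "        # One valid length per example, shared by all of its queries\n",
    "        valid_lens = valid_lens[:, None].expand(X.shape[0], X.shape[1])\n",
    "    # Key positions (1, 1, no. of keys) vs. valid lengths (batch_size, no. of queries, 1)\n",
    "    positions = torch.arange(X.shape[-1], device=X.device)[None, None, :]\n",
    "    mask = positions < valid_lens[:, :, None]\n",
    "    # Fill masked entries with a very large negative value before the softmax\n",
    "    X = X.masked_fill(~mask, -1e6)\n",
    "    return nn.functional.softmax(X, dim=-1)\n",
    "\n",
    "X = torch.rand(2, 2, 4)\n",
    "print(masked_softmax_sketch(X, torch.tensor([2, 3])))\n",
    "print(masked_softmax_sketch(X, torch.tensor([[1, 3], [2, 4]])))\n",
    "```\n",
    "\n",
    "Both calls should show the same masking pattern as the two `masked_softmax` examples above (the random inputs differ).\n"
   ]
  },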
  {
   "cell_type": "markdown",
   "id": "f184f421",
   "metadata": {
    "origin_pos": 20
   },
   "source": [
    "## [**Additive Attention**]\n",
    ":label:`subsec_additive-attention`\n",
    "\n",
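    "In general, when queries and keys are vectors of different lengths, we can use additive attention as the scoring function.\n",
    "Given a query $\\mathbf{q} \\in \\mathbb{R}^q$ and\n",
    "a key $\\mathbf{k} \\in \\mathbb{R}^k$,\n",
    "the *additive attention* (additive attention) scoring function is\n",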
    "\n",
    "$$a(\\mathbf q, \\mathbf k) = \\mathbf w_v^\\top \\tanh(\\mathbf W_q\\mathbf q + \\mathbf W_k \\mathbf k) \\in \\mathbb{R},$$\n",
    ":eqlabel:`eq_additive-attn`\n",
    "\n",
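    "where the learnable parameters are $\\mathbf W_q\\in\\mathbb R^{h\\times q}$,\n",
    "$\\mathbf W_k\\in\\mathbb R^{h\\times k}$, and\n",
    "$\\mathbf w_v\\in\\mathbb R^{h}$.\n",
    "As :eqref:`eq_additive-attn` shows,\n",
    "this is equivalent to concatenating the query and the key and feeding the result into a multilayer perceptron (MLP)\n",
    "with a single hidden layer whose number of hidden units is $h$, a hyperparameter.\n",
    "The hidden layer uses $\\tanh$ as its activation function and has no bias terms.\n",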
    "\n",
    "Next, we implement additive attention.\n"
   ]
  },
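  {
   "cell_type": "markdown",
   "id": "sketch-additive-concat",
   "metadata": {},
   "source": [
    "The concatenation view can be checked numerically:\n",
    "$\\mathbf W_q\\mathbf q + \\mathbf W_k \\mathbf k$ equals the block matrix $[\\mathbf W_q, \\mathbf W_k]$\n",
    "applied to the stacked vector $[\\mathbf q; \\mathbf k]$.\n",
    "The minimal sketch below assumes random weights and illustrative sizes $q=3$, $k=2$, and $h=4$.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "q_dim, k_dim, h = 3, 2, 4          # illustrative sizes (assumption)\n",
    "W_q = torch.randn(h, q_dim)        # W_q in R^{h x q}\n",
    "W_k = torch.randn(h, k_dim)        # W_k in R^{h x k}\n",
    "w_v = torch.randn(h)               # w_v in R^h\n",
    "q, k = torch.randn(q_dim), torch.randn(k_dim)\n",
    "\n",
    "# Additive score computed term by term: w_v^T tanh(W_q q + W_k k)\n",
    "score = w_v @ torch.tanh(W_q @ q + W_k @ k)\n",
    "\n",
    "# Same score from one bias-free hidden layer applied to the concatenation [q; k]\n",
    "W_qk = torch.cat([W_q, W_k], dim=1)    # shape: (h, q + k)\n",
    "score_concat = w_v @ torch.tanh(W_qk @ torch.cat([q, k]))\n",
    "\n",
    "print(score.item(), score_concat.item())  # the two scalars agree\n",
    "```\n"
   ]
  },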
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c67a2f61",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:59:00.265207Z",
     "iopub.status.busy": "2023-08-18T06:59:00.263273Z",
     "iopub.status.idle": "2023-08-18T06:59:00.277358Z",
     "shell.execute_reply": "2023-08-18T06:59:00.276229Z"
    },
    "origin_pos": 22,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "class AdditiveAttention(nn.Module):\n",
    "    \"\"\"Additive attention.\"\"\"\n",
    "    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):\n",
    "        super(AdditiveAttention, self).__init__(**kwargs)\n",
    "        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)\n",
    "        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)\n",
    "        self.w_v = nn.Linear(num_hiddens, 1, bias=False)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, queries, keys, values, valid_lens):\n",
    "        queries, keys = self.W_q(queries), self.W_k(keys)\n",
    "        # After dimension expansion,\n",
    "        # shape of queries: (batch_size, no. of queries, 1, num_hiddens)\n",
    "        # shape of keys: (batch_size, 1, no. of key-value pairs, num_hiddens)\n",
    "        # Sum them up with broadcasting\n",
    "        features = queries.unsqueeze(2) + keys.unsqueeze(1)\n",
    "        features = torch.tanh(features)\n",
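    "        # There is only one output of self.w_v, so we remove the last\n",
    "        # one-dimensional entry from the shape.\n",
    "        # Shape of scores: (batch_size, no. of queries, no. of key-value pairs)\n",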
    "        scores = self.w_v(features).squeeze(-1)\n",
    "        self.attention_weights = masked_softmax(scores, valid_lens)\n",
    "        # Shape of values: (batch_size, no. of key-value pairs, value dimension)\n",
    "        return torch.bmm(self.dropout(self.attention_weights), values)"
   ]
  },
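  {
   "cell_type": "markdown",
   "id": "sketch-additive-broadcast",
   "metadata": {},
   "source": [
    "The broadcast inside `forward` can be traced with the sizes used in the toy example below\n",
    "(2 examples, 1 query, 10 key-value pairs, `num_hiddens=8`).\n",
    "This sketch assumes zero tensors as stand-ins for the projected queries and keys, since only the shapes matter here.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "batch_size, num_queries, num_kv_pairs, num_hiddens = 2, 1, 10, 8\n",
    "# Stand-ins for W_q(queries) and W_k(keys); only the shapes matter\n",
    "queries = torch.zeros(batch_size, num_queries, num_hiddens)\n",
    "keys = torch.zeros(batch_size, num_kv_pairs, num_hiddens)\n",
    "# (2, 1, 1, 8) + (2, 1, 10, 8) broadcasts to (2, 1, 10, 8)\n",
    "features = queries.unsqueeze(2) + keys.unsqueeze(1)\n",
    "print(features.shape)  # torch.Size([2, 1, 10, 8])\n",
    "# Project each hidden vector to a scalar score, then drop the size-1 last axis\n",
    "w_v = nn.Linear(num_hiddens, 1, bias=False)\n",
    "scores = w_v(torch.tanh(features)).squeeze(-1)\n",
    "print(scores.shape)  # torch.Size([2, 1, 10])\n",
    "```\n"
   ]
  },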
  {
   "cell_type": "markdown",
   "id": "c134134a",
   "metadata": {
    "origin_pos": 25
   },
   "source": [
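    "Let's [**demonstrate the above `AdditiveAttention` class**] with a toy example,\n",
    "where the shapes (batch size, number of steps or sequence length in tokens, feature size)\n",
    "of queries, keys, and values are $(2,1,20)$, $(2,10,2)$,\n",
    "and $(2,10,4)$, respectively.\n",
    "The attention pooling output has a shape of (batch size, number of steps for queries, feature size for values).\n"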
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "764a05d4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:59:00.284992Z",
     "iopub.status.busy": "2023-08-18T06:59:00.283120Z",
     "iopub.status.idle": "2023-08-18T06:59:00.309249Z",
     "shell.execute_reply": "2023-08-18T06:59:00.308150Z"
    },
    "origin_pos": 27,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],\n",
       "\n",
       "        [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward0>)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))\n",
    "# The two value matrices in the values minibatch are identical\n",
    "values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(\n",
    "    2, 1, 1)\n",
    "valid_lens = torch.tensor([2, 6])\n",
    "\n",
    "attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,\n",
    "                              dropout=0.1)\n",
    "attention.eval()\n",
    "attention(queries, keys, values, valid_lens)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91b8ed50",
   "metadata": {
    "origin_pos": 30
   },
   "source": [
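    "Although additive attention contains learnable parameters,\n",
    "since every key is the same in this example,\n",
    "[**the attention weights**] are uniform, determined by the specified valid lengths.\n"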
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2cc35bc4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:59:00.316794Z",
     "iopub.status.busy": "2023-08-18T06:59:00.315016Z",
     "iopub.status.idle": "2023-08-18T06:59:00.531804Z",
     "shell.execute_reply": "2023-08-18T06:59:00.530673Z"
    },
    "origin_pos": 31,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"186.99575pt\" height=\"101.818906pt\" viewBox=\"0 0 186.99575 101.818906\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
       " <metadata>\n",
       "  <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
       "   <cc:Work>\n",
       "    <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
       "    <dc:date>2023-08-18T06:59:00.491487</dc:date>\n",
       "    <dc:format>image/svg+xml</dc:format>\n",
       "    <dc:creator>\n",
       "     <cc:Agent>\n",
       "      <dc:title>Matplotlib v3.5.1, https://matplotlib.org/</dc:title>\n",
       "     </cc:Agent>\n",
       "    </dc:creator>\n",
       "   </cc:Work>\n",
       "  </rdf:RDF>\n",
       " </metadata>\n",
       " <defs>\n",
       "  <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
       " </defs>\n",
       " <g id=\"figure_1\">\n",
       "  <g id=\"patch_1\">\n",
       "   <path d=\"M -0 101.818906 \n",
       "L 186.99575 101.818906 \n",
       "L 186.99575 0 \n",
       "L -0 0 \n",
       "L -0 101.818906 \n",
       "z\n",
       "\" style=\"fill: none\"/>\n",
       "  </g>\n",
       "  <g id=\"axes_1\">\n",
       "   <g id=\"patch_2\">\n",
       "    <path d=\"M 34.240625 59.13 \n",
       "L 145.840625 59.13 \n",
       "L 145.840625 36.81 \n",
       "L 34.240625 36.81 \n",
       "z\n",
       "\" style=\"fill: #ffffff\"/>\n",
       "   </g>\n",
       "   <g clip-path=\"url(#pc317cd931d)\">\n",
       "    <image xlink:href=\"data:image/png;base64,\n",
       "iVBORw0KGgoAAAANSUhEUgAAAHAAAAAXCAYAAADTEcupAAAAeElEQVR4nO3YsQmAQBAF0X+aWp8tGFqDXZjZnC2YCV5uJnIcA/MKWBaGTbbcx/ZESZJxXnuv8NnQewH9Y0A4A8IZEM6AcAaEMyCcAeEMCGdAOAPClSVTk1/ofp0txurFC4QzIJwB4QwIZ0A4A8IZEM6AcAaEMyBcBc2IB/NA+hblAAAAAElFTkSuQmCC\" id=\"imageea0a518805\" transform=\"scale(1 -1)translate(0 -23)\" x=\"34.240625\" y=\"-36.13\" width=\"112\" height=\"23\"/>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_1\">\n",
       "    <g id=\"xtick_1\">\n",
       "     <g id=\"line2d_1\">\n",
       "      <defs>\n",
       "       <path id=\"m92fab040e5\" d=\"M 0 0 \n",
       "L 0 3.5 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#m92fab040e5\" x=\"39.820625\" y=\"59.13\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_1\">\n",
       "      <!-- 0 -->\n",
       "      <g transform=\"translate(36.639375 73.728437)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
       "Q 1547 4250 1301 3770 \n",
       "Q 1056 3291 1056 2328 \n",
       "Q 1056 1369 1301 889 \n",
       "Q 1547 409 2034 409 \n",
       "Q 2525 409 2770 889 \n",
       "Q 3016 1369 3016 2328 \n",
       "Q 3016 3291 2770 3770 \n",
       "Q 2525 4250 2034 4250 \n",
       "z\n",
       "M 2034 4750 \n",
       "Q 2819 4750 3233 4129 \n",
       "Q 3647 3509 3647 2328 \n",
       "Q 3647 1150 3233 529 \n",
       "Q 2819 -91 2034 -91 \n",
       "Q 1250 -91 836 529 \n",
       "Q 422 1150 422 2328 \n",
       "Q 422 3509 836 4129 \n",
       "Q 1250 4750 2034 4750 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_2\">\n",
       "     <g id=\"line2d_2\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m92fab040e5\" x=\"95.620625\" y=\"59.13\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_2\">\n",
       "      <!-- 5 -->\n",
       "      <g transform=\"translate(92.439375 73.728437)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-35\" d=\"M 691 4666 \n",
       "L 3169 4666 \n",
       "L 3169 4134 \n",
       "L 1269 4134 \n",
       "L 1269 2991 \n",
       "Q 1406 3038 1543 3061 \n",
       "Q 1681 3084 1819 3084 \n",
       "Q 2600 3084 3056 2656 \n",
       "Q 3513 2228 3513 1497 \n",
       "Q 3513 744 3044 326 \n",
       "Q 2575 -91 1722 -91 \n",
       "Q 1428 -91 1123 -41 \n",
       "Q 819 9 494 109 \n",
       "L 494 744 \n",
       "Q 775 591 1075 516 \n",
       "Q 1375 441 1709 441 \n",
       "Q 2250 441 2565 725 \n",
       "Q 2881 1009 2881 1497 \n",
       "Q 2881 1984 2565 2268 \n",
       "Q 2250 2553 1709 2553 \n",
       "Q 1456 2553 1204 2497 \n",
       "Q 953 2441 691 2322 \n",
       "L 691 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-35\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_3\">\n",
       "     <!-- Keys -->\n",
       "     <g transform=\"translate(78.371094 87.406562)scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-4b\" d=\"M 628 4666 \n",
       "L 1259 4666 \n",
       "L 1259 2694 \n",
       "L 3353 4666 \n",
       "L 4166 4666 \n",
       "L 1850 2491 \n",
       "L 4331 0 \n",
       "L 3500 0 \n",
       "L 1259 2247 \n",
       "L 1259 0 \n",
       "L 628 0 \n",
       "L 628 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-65\" d=\"M 3597 1894 \n",
       "L 3597 1613 \n",
       "L 953 1613 \n",
       "Q 991 1019 1311 708 \n",
       "Q 1631 397 2203 397 \n",
       "Q 2534 397 2845 478 \n",
       "Q 3156 559 3463 722 \n",
       "L 3463 178 \n",
       "Q 3153 47 2828 -22 \n",
       "Q 2503 -91 2169 -91 \n",
       "Q 1331 -91 842 396 \n",
       "Q 353 884 353 1716 \n",
       "Q 353 2575 817 3079 \n",
       "Q 1281 3584 2069 3584 \n",
       "Q 2775 3584 3186 3129 \n",
       "Q 3597 2675 3597 1894 \n",
       "z\n",
       "M 3022 2063 \n",
       "Q 3016 2534 2758 2815 \n",
       "Q 2500 3097 2075 3097 \n",
       "Q 1594 3097 1305 2825 \n",
       "Q 1016 2553 972 2059 \n",
       "L 3022 2063 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-79\" d=\"M 2059 -325 \n",
       "Q 1816 -950 1584 -1140 \n",
       "Q 1353 -1331 966 -1331 \n",
       "L 506 -1331 \n",
       "L 506 -850 \n",
       "L 844 -850 \n",
       "Q 1081 -850 1212 -737 \n",
       "Q 1344 -625 1503 -206 \n",
       "L 1606 56 \n",
       "L 191 3500 \n",
       "L 800 3500 \n",
       "L 1894 763 \n",
       "L 2988 3500 \n",
       "L 3597 3500 \n",
       "L 2059 -325 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
       "L 2834 2853 \n",
       "Q 2591 2978 2328 3040 \n",
       "Q 2066 3103 1784 3103 \n",
       "Q 1356 3103 1142 2972 \n",
       "Q 928 2841 928 2578 \n",
       "Q 928 2378 1081 2264 \n",
       "Q 1234 2150 1697 2047 \n",
       "L 1894 2003 \n",
       "Q 2506 1872 2764 1633 \n",
       "Q 3022 1394 3022 966 \n",
       "Q 3022 478 2636 193 \n",
       "Q 2250 -91 1575 -91 \n",
       "Q 1294 -91 989 -36 \n",
       "Q 684 19 347 128 \n",
       "L 347 722 \n",
       "Q 666 556 975 473 \n",
       "Q 1284 391 1588 391 \n",
       "Q 1994 391 2212 530 \n",
       "Q 2431 669 2431 922 \n",
       "Q 2431 1156 2273 1281 \n",
       "Q 2116 1406 1581 1522 \n",
       "L 1381 1569 \n",
       "Q 847 1681 609 1914 \n",
       "Q 372 2147 372 2553 \n",
       "Q 372 3047 722 3315 \n",
       "Q 1072 3584 1716 3584 \n",
       "Q 2034 3584 2315 3537 \n",
       "Q 2597 3491 2834 3397 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-4b\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-65\" x=\"60.576172\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-79\" x=\"122.099609\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-73\" x=\"181.279297\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_2\">\n",
       "    <g id=\"ytick_1\">\n",
       "     <g id=\"line2d_3\">\n",
       "      <defs>\n",
       "       <path id=\"m15a60364a0\" d=\"M 0 0 \n",
       "L -3.5 0 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#m15a60364a0\" x=\"34.240625\" y=\"42.39\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_4\">\n",
       "      <!-- 0 -->\n",
       "      <g transform=\"translate(20.878125 46.189219)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_2\">\n",
       "     <g id=\"line2d_4\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m15a60364a0\" x=\"34.240625\" y=\"53.55\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_5\">\n",
       "      <!-- 1 -->\n",
       "      <g transform=\"translate(20.878125 57.349219)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
       "L 1825 531 \n",
       "L 1825 4091 \n",
       "L 703 3866 \n",
       "L 703 4441 \n",
       "L 1819 4666 \n",
       "L 2450 4666 \n",
       "L 2450 531 \n",
       "L 3481 531 \n",
       "L 3481 0 \n",
       "L 794 0 \n",
       "L 794 531 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_6\">\n",
       "     <!-- Queries -->\n",
       "     <g transform=\"translate(14.798437 67.277031)rotate(-90)scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-51\" d=\"M 2522 4238 \n",
       "Q 1834 4238 1429 3725 \n",
       "Q 1025 3213 1025 2328 \n",
       "Q 1025 1447 1429 934 \n",
       "Q 1834 422 2522 422 \n",
       "Q 3209 422 3611 934 \n",
       "Q 4013 1447 4013 2328 \n",
       "Q 4013 3213 3611 3725 \n",
       "Q 3209 4238 2522 4238 \n",
       "z\n",
       "M 3406 84 \n",
       "L 4238 -825 \n",
       "L 3475 -825 \n",
       "L 2784 -78 \n",
       "Q 2681 -84 2626 -87 \n",
       "Q 2572 -91 2522 -91 \n",
       "Q 1538 -91 948 567 \n",
       "Q 359 1225 359 2328 \n",
       "Q 359 3434 948 4092 \n",
       "Q 1538 4750 2522 4750 \n",
       "Q 3503 4750 4090 4092 \n",
       "Q 4678 3434 4678 2328 \n",
       "Q 4678 1516 4351 937 \n",
       "Q 4025 359 3406 84 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-75\" d=\"M 544 1381 \n",
       "L 544 3500 \n",
       "L 1119 3500 \n",
       "L 1119 1403 \n",
       "Q 1119 906 1312 657 \n",
       "Q 1506 409 1894 409 \n",
       "Q 2359 409 2629 706 \n",
       "Q 2900 1003 2900 1516 \n",
       "L 2900 3500 \n",
       "L 3475 3500 \n",
       "L 3475 0 \n",
       "L 2900 0 \n",
       "L 2900 538 \n",
       "Q 2691 219 2414 64 \n",
       "Q 2138 -91 1772 -91 \n",
       "Q 1169 -91 856 284 \n",
       "Q 544 659 544 1381 \n",
       "z\n",
       "M 1991 3584 \n",
       "L 1991 3584 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-72\" d=\"M 2631 2963 \n",
       "Q 2534 3019 2420 3045 \n",
       "Q 2306 3072 2169 3072 \n",
       "Q 1681 3072 1420 2755 \n",
       "Q 1159 2438 1159 1844 \n",
       "L 1159 0 \n",
       "L 581 0 \n",
       "L 581 3500 \n",
       "L 1159 3500 \n",
       "L 1159 2956 \n",
       "Q 1341 3275 1631 3429 \n",
       "Q 1922 3584 2338 3584 \n",
       "Q 2397 3584 2469 3576 \n",
       "Q 2541 3569 2628 3553 \n",
       "L 2631 2963 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-69\" d=\"M 603 3500 \n",
       "L 1178 3500 \n",
       "L 1178 0 \n",
       "L 603 0 \n",
       "L 603 3500 \n",
       "z\n",
       "M 603 4863 \n",
       "L 1178 4863 \n",
       "L 1178 4134 \n",
       "L 603 4134 \n",
       "L 603 4863 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-51\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-75\" x=\"78.710938\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-65\" x=\"142.089844\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-72\" x=\"203.613281\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-69\" x=\"244.726562\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-65\" x=\"272.509766\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-73\" x=\"334.033203\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"patch_3\">\n",
       "    <path d=\"M 34.240625 59.13 \n",
       "L 34.240625 36.81 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_4\">\n",
       "    <path d=\"M 145.840625 59.13 \n",
       "L 145.840625 36.81 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_5\">\n",
       "    <path d=\"M 34.240625 59.13 \n",
       "L 145.840625 59.13 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_6\">\n",
       "    <path d=\"M 34.240625 36.81 \n",
       "L 145.840625 36.81 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "  </g>\n",
       "  <g id=\"axes_2\">\n",
       "   <g id=\"patch_7\">\n",
       "    <path d=\"M 152.815625 88.74 \n",
       "L 156.892625 88.74 \n",
       "L 156.892625 7.2 \n",
       "L 152.815625 7.2 \n",
       "z\n",
       "\" style=\"fill: #ffffff\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_8\">\n",
       "    <path clip-path=\"url(#p1f870602a9)\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.01; stroke-linejoin: miter\"/>\n",
       "   </g>\n",
       "   <image xlink:href=\"data:image/png;base64,\n",
       "iVBORw0KGgoAAAANSUhEUgAAAAQAAABSCAYAAABzJnWUAAAAm0lEQVR4nJ2SOw7CUBADHxL3vyoNBeyXFsZIFqQcje1NlMveb3venuuZPg6MMXbKdvxcqqASoDG7+TSG7xBjm5GSSLgVjXiDQO4Izvq39SsepHxC/g5/lLJDjA2C4mzKHQ5MjjF01q50obR7P0E1jYIhkZRSgmKk7UoNI8tZgiLA6hdDI0cM1yF3BE//YqA0xCB4SClB8FK5g6UvE4PMuTPJ8jwAAAAASUVORK5CYII=\" id=\"image8ec47ac8a1\" transform=\"scale(1 -1)translate(0 -82)\" x=\"153\" y=\"-6\" width=\"4\" height=\"82\"/>\n",
       "   <g id=\"matplotlib.axis_3\">\n",
       "    <g id=\"ytick_3\">\n",
       "     <g id=\"line2d_5\">\n",
       "      <defs>\n",
       "       <path id=\"m2176ad570c\" d=\"M 0 0 \n",
       "L 3.5 0 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#m2176ad570c\" x=\"156.892625\" y=\"88.74\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_7\">\n",
       "      <!-- 0.0 -->\n",
       "      <g transform=\"translate(163.892625 92.539219)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-2e\" d=\"M 684 794 \n",
       "L 1344 794 \n",
       "L 1344 0 \n",
       "L 684 0 \n",
       "L 684 794 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_4\">\n",
       "     <g id=\"line2d_6\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m2176ad570c\" x=\"156.892625\" y=\"56.124\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_8\">\n",
       "      <!-- 0.2 -->\n",
       "      <g transform=\"translate(163.892625 59.923219)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
       "L 3431 531 \n",
       "L 3431 0 \n",
       "L 469 0 \n",
       "L 469 531 \n",
       "Q 828 903 1448 1529 \n",
       "Q 2069 2156 2228 2338 \n",
       "Q 2531 2678 2651 2914 \n",
       "Q 2772 3150 2772 3378 \n",
       "Q 2772 3750 2511 3984 \n",
       "Q 2250 4219 1831 4219 \n",
       "Q 1534 4219 1204 4116 \n",
       "Q 875 4013 500 3803 \n",
       "L 500 4441 \n",
       "Q 881 4594 1212 4672 \n",
       "Q 1544 4750 1819 4750 \n",
       "Q 2544 4750 2975 4387 \n",
       "Q 3406 4025 3406 3419 \n",
       "Q 3406 3131 3298 2873 \n",
       "Q 3191 2616 2906 2266 \n",
       "Q 2828 2175 2409 1742 \n",
       "Q 1991 1309 1228 531 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-32\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_5\">\n",
       "     <g id=\"line2d_7\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m2176ad570c\" x=\"156.892625\" y=\"23.508\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_9\">\n",
       "      <!-- 0.4 -->\n",
       "      <g transform=\"translate(163.892625 27.307219)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \n",
       "L 825 1625 \n",
       "L 2419 1625 \n",
       "L 2419 4116 \n",
       "z\n",
       "M 2253 4666 \n",
       "L 3047 4666 \n",
       "L 3047 1625 \n",
       "L 3713 1625 \n",
       "L 3713 1100 \n",
       "L 3047 1100 \n",
       "L 3047 0 \n",
       "L 2419 0 \n",
       "L 2419 1100 \n",
       "L 313 1100 \n",
       "L 313 1709 \n",
       "L 2253 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-34\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"LineCollection_1\"/>\n",
       "   <g id=\"patch_9\">\n",
       "    <path d=\"M 152.815625 88.74 \n",
       "L 154.854125 88.74 \n",
       "L 156.892625 88.74 \n",
       "L 156.892625 7.2 \n",
       "L 154.854125 7.2 \n",
       "L 152.815625 7.2 \n",
       "L 152.815625 88.74 \n",
       "z\n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "  </g>\n",
       " </g>\n",
       " <defs>\n",
       "  <clipPath id=\"pc317cd931d\">\n",
       "   <rect x=\"34.240625\" y=\"36.81\" width=\"111.6\" height=\"22.32\"/>\n",
       "  </clipPath>\n",
       "  <clipPath id=\"p1f870602a9\">\n",
       "   <rect x=\"152.815625\" y=\"7.2\" width=\"4.077\" height=\"81.54\"/>\n",
       "  </clipPath>\n",
       " </defs>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<Figure size 180x180 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),\n",
    "                  xlabel='Keys', ylabel='Queries')"
   ]
  },
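  {
   "cell_type": "markdown",
   "id": "sketch-uniform-weights",
   "metadata": {},
   "source": [
    "Since every key in the additive attention example is identical, all scores within a row are equal,\n",
    "so the weights inside the valid lengths are $1/2$ and $1/6$,\n",
    "and the pooling output is simply the mean of the first $2$ (respectively $6$) rows of `values`.\n",
    "A minimal check that rebuilds only the `values` and `valid_lens` tensors from that example:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)\n",
    "valid_lens = torch.tensor([2, 6])\n",
    "# Uniform weights 1 / valid_len: average the first valid_len rows of each example\n",
    "expected = torch.stack([values[i, :int(n)].mean(dim=0)\n",
    "                        for i, n in enumerate(valid_lens)])\n",
    "print(expected)  # compare with the output of attention(queries, keys, values, valid_lens)\n",
    "```\n"
   ]
  },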
  {
   "cell_type": "markdown",
   "id": "efae8fbc",
   "metadata": {
    "origin_pos": 32
   },
   "source": [
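    "## [**Scaled Dot-Product Attention**]\n",
    "\n",
    "A more computationally efficient scoring function is the dot product,\n",
    "but the dot product requires the query and the key to have the same length $d$.\n",
    "Assume that all elements of the query and the key are independent random variables\n",
    "with zero mean and unit variance.\n",
    "Then each of the $d$ products $q_i k_i$ has mean $0$ and variance $1$,\n",
    "and since the variances of independent terms add up,\n",
    "the dot product of the two vectors has mean $0$ and variance $d$.\n",
    "To keep the variance of the dot product at $1$ regardless of the vector length,\n",
    "we divide the dot product by $\\sqrt{d}$,\n",
    "which gives the *scaled dot-product attention* scoring function:\n",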
    "\n",
    "$$a(\\mathbf q, \\mathbf k) = \\mathbf{q}^\\top \\mathbf{k}  /\\sqrt{d}.$$\n",
    "\n",
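    "As a quick numerical check of this variance argument, the following sketch samples standard normal queries and keys and compares the sample variance of raw and scaled dot products.\n",
    "It is a standalone illustration, not part of the `d2l` package. The dimension `d = 256` and the sample count are arbitrary assumptions.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "d, num_samples = 256, 10000  # arbitrary choices for this illustration\n",
    "# Each row is an independent query/key vector with zero-mean, unit-variance entries\n",
    "q = torch.randn(num_samples, d)\n",
    "k = torch.randn(num_samples, d)\n",
    "\n",
    "raw = (q * k).sum(dim=-1)    # one dot product per row\n",
    "scaled = raw / math.sqrt(d)  # scaled dot product\n",
    "\n",
    "print(f'variance of raw dot products:    {raw.var().item():.1f}')     # close to d\n",
    "print(f'variance of scaled dot products: {scaled.var().item():.3f}')  # close to 1\n",
    "```\n",
    "\n",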
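    "In practice, we often work with minibatches for efficiency,\n",
    "for example computing attention for $n$ queries and $m$ key-value pairs,\n",
    "where queries and keys have length $d$ and values have length $v$.\n",
    "The scaled dot-product attention of queries $\\mathbf Q\\in\\mathbb R^{n\\times d}$,\n",
    "keys $\\mathbf K\\in\\mathbb R^{m\\times d}$, and\n",
    "values $\\mathbf V\\in\\mathbb R^{m\\times v}$ is\n",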
    "\n",
    "$$ \\mathrm{softmax}\\left(\\frac{\\mathbf Q \\mathbf K^\\top }{\\sqrt{d}}\\right) \\mathbf V \\in \\mathbb{R}^{n\\times v}.$$\n",
    ":eqlabel:`eq_softmax_QK_V`\n",
    "\n",
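    "Before turning to the batched, masked implementation, here is a minimal sketch of the matrix formula above for a single set of unmasked inputs. The sizes $n=3$, $m=5$, $d=4$, and $v=6$ are arbitrary.\n",
    "The helper name `sdp_attention` is hypothetical. The optional cross-check assumes PyTorch 2.0 or later, which provides `torch.nn.functional.scaled_dot_product_attention`.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def sdp_attention(Q, K, V):\n",
    "    # Q: (n, d), K: (m, d), V: (m, v) -> output (n, v), weights (n, m)\n",
    "    d = Q.shape[-1]\n",
    "    scores = Q @ K.T / math.sqrt(d)          # (n, m)\n",
    "    weights = torch.softmax(scores, dim=-1)  # each row sums to 1\n",
    "    return weights @ V, weights\n",
    "\n",
    "torch.manual_seed(0)\n",
    "n, m, d, v = 3, 5, 4, 6\n",
    "Q, K, V = torch.randn(n, d), torch.randn(m, d), torch.randn(m, v)\n",
    "out, weights = sdp_attention(Q, K, V)\n",
    "print(out.shape, weights.sum(dim=-1))\n",
    "\n",
    "# Optional cross-check against PyTorch's fused operator (PyTorch >= 2.0)\n",
    "if hasattr(F, 'scaled_dot_product_attention'):\n",
    "    ref = F.scaled_dot_product_attention(Q[None], K[None], V[None])[0]\n",
    "    print(torch.allclose(out, ref, atol=1e-5))\n",
    "```\n",
    "\n",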
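    "The following implementation of scaled dot-product attention applies dropout to the attention weights for model regularization.\n"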
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "974ab3dd",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:59:00.537439Z",
     "iopub.status.busy": "2023-08-18T06:59:00.536367Z",
     "iopub.status.idle": "2023-08-18T06:59:00.545728Z",
     "shell.execute_reply": "2023-08-18T06:59:00.544563Z"
    },
    "origin_pos": 34,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "class DotProductAttention(nn.Module):\n",
    "    \"\"\"Scaled dot-product attention.\"\"\"\n",
    "    def __init__(self, dropout, **kwargs):\n",
    "        super(DotProductAttention, self).__init__(**kwargs)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
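    "    # Shape of queries: (batch_size, no. of queries, d)\n",
    "    # Shape of keys: (batch_size, no. of key-value pairs, d)\n",
    "    # Shape of values: (batch_size, no. of key-value pairs, value dimension)\n",
    "    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)\n",
    "    def forward(self, queries, keys, values, valid_lens=None):\n",
    "        d = queries.shape[-1]\n",
    "        # Swap the last two dimensions of keys so that bmm computes Q K^T\n",
    "        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)\n",
    "        # Positions beyond valid_lens receive zero attention weight\n",
    "        self.attention_weights = masked_softmax(scores, valid_lens)\n",
    "        # Dropout is applied to the attention weights before pooling the values\n",
    "        return torch.bmm(self.dropout(self.attention_weights), values)"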
   ]
  },
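  {
   "cell_type": "markdown",
   "id": "3b9e4d21",
   "metadata": {},
   "source": [
    "Because dropout is applied to the attention weights, the layer behaves differently in training and evaluation modes. This is why the demonstration below calls `attention.eval()` to make the output deterministic.\n",
    "The standalone sketch below uses an arbitrary row of uniform weights to illustrate `nn.Dropout` with `p=0.5`. In training mode, each entry is zeroed with probability $p$ and the surviving entries are scaled by $1/(1-p)$. In evaluation mode, dropout is the identity.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "dropout = nn.Dropout(p=0.5)\n",
    "weights = torch.full((1, 4), 0.25)  # a row of uniform attention weights\n",
    "\n",
    "dropout.train()\n",
    "print(dropout(weights))  # each entry is either 0 or 0.25 / (1 - 0.5) = 0.5\n",
    "dropout.eval()\n",
    "print(dropout(weights))  # unchanged in evaluation mode\n",
    "```"
   ]
  },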
  {
   "cell_type": "markdown",
   "id": "78bb92bb",
   "metadata": {
    "origin_pos": 37
   },
   "source": [
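    "To [**demonstrate the `DotProductAttention` class above**],\n",
    "we use the same keys, values, and valid lengths as in the earlier additive attention example.\n",
    "Because the dot product requires it, we give the queries the same feature dimension as the keys.\n"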
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f4eb4277",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:59:00.550788Z",
     "iopub.status.busy": "2023-08-18T06:59:00.549812Z",
     "iopub.status.idle": "2023-08-18T06:59:00.562245Z",
     "shell.execute_reply": "2023-08-18T06:59:00.561174Z"
    },
    "origin_pos": 39,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],\n",
       "\n",
       "        [[10.0000, 11.0000, 12.0000, 13.0000]]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "queries = torch.normal(0, 1, (2, 1, 2))\n",
    "attention = DotProductAttention(dropout=0.5)\n",
    "attention.eval()\n",
    "attention(queries, keys, values, valid_lens)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4b6a04e",
   "metadata": {
    "origin_pos": 42
   },
   "source": [
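    "As in the additive attention demonstration, all the keys are identical, so no query can distinguish among them.\n",
    "Hence we obtain [**uniform attention weights**] over the positions that are not masked out by `valid_lens`.\n"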
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "76040da6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:59:00.567093Z",
     "iopub.status.busy": "2023-08-18T06:59:00.566474Z",
     "iopub.status.idle": "2023-08-18T06:59:00.804899Z",
     "shell.execute_reply": "2023-08-18T06:59:00.803678Z"
    },
    "origin_pos": 43,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"186.99575pt\" height=\"101.818906pt\" viewBox=\"0 0 186.99575 101.818906\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
       " <metadata>\n",
       "  <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
       "   <cc:Work>\n",
       "    <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
       "    <dc:date>2023-08-18T06:59:00.757633</dc:date>\n",
       "    <dc:format>image/svg+xml</dc:format>\n",
       "    <dc:creator>\n",
       "     <cc:Agent>\n",
       "      <dc:title>Matplotlib v3.5.1, https://matplotlib.org/</dc:title>\n",
       "     </cc:Agent>\n",
       "    </dc:creator>\n",
       "   </cc:Work>\n",
       "  </rdf:RDF>\n",
       " </metadata>\n",
       " <defs>\n",
       "  <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
       " </defs>\n",
       " <g id=\"figure_1\">\n",
       "  <g id=\"patch_1\">\n",
       "   <path d=\"M -0 101.818906 \n",
       "L 186.99575 101.818906 \n",
       "L 186.99575 0 \n",
       "L -0 0 \n",
       "L -0 101.818906 \n",
       "z\n",
       "\" style=\"fill: none\"/>\n",
       "  </g>\n",
       "  <g id=\"axes_1\">\n",
       "   <g id=\"patch_2\">\n",
       "    <path d=\"M 34.240625 59.13 \n",
       "L 145.840625 59.13 \n",
       "L 145.840625 36.81 \n",
       "L 34.240625 36.81 \n",
       "z\n",
       "\" style=\"fill: #ffffff\"/>\n",
       "   </g>\n",
       "   <g clip-path=\"url(#pf0fd76f9ae)\">\n",
       "    <image xlink:href=\"data:image/png;base64,\n",
       "iVBORw0KGgoAAAANSUhEUgAAAHAAAAAXCAYAAADTEcupAAAAeElEQVR4nO3YsQmAQBAF0X+aWp8tGFqDXZjZnC2YCV5uJnIcA/MKWBaGTbbcx/ZESZJxXnuv8NnQewH9Y0A4A8IZEM6AcAaEMyCcAeEMCGdAOAPClSVTk1/ofp0txurFC4QzIJwB4QwIZ0A4A8IZEM6AcAaEMyBcBc2IB/NA+hblAAAAAElFTkSuQmCC\" id=\"image796dfede78\" transform=\"scale(1 -1)translate(0 -23)\" x=\"34.240625\" y=\"-36.13\" width=\"112\" height=\"23\"/>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_1\">\n",
       "    <g id=\"xtick_1\">\n",
       "     <g id=\"line2d_1\">\n",
       "      <defs>\n",
       "       <path id=\"m9118eec785\" d=\"M 0 0 \n",
       "L 0 3.5 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#m9118eec785\" x=\"39.820625\" y=\"59.13\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_1\">\n",
       "      <!-- 0 -->\n",
       "      <g transform=\"translate(36.639375 73.728437)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
       "Q 1547 4250 1301 3770 \n",
       "Q 1056 3291 1056 2328 \n",
       "Q 1056 1369 1301 889 \n",
       "Q 1547 409 2034 409 \n",
       "Q 2525 409 2770 889 \n",
       "Q 3016 1369 3016 2328 \n",
       "Q 3016 3291 2770 3770 \n",
       "Q 2525 4250 2034 4250 \n",
       "z\n",
       "M 2034 4750 \n",
       "Q 2819 4750 3233 4129 \n",
       "Q 3647 3509 3647 2328 \n",
       "Q 3647 1150 3233 529 \n",
       "Q 2819 -91 2034 -91 \n",
       "Q 1250 -91 836 529 \n",
       "Q 422 1150 422 2328 \n",
       "Q 422 3509 836 4129 \n",
       "Q 1250 4750 2034 4750 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_2\">\n",
       "     <g id=\"line2d_2\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m9118eec785\" x=\"95.620625\" y=\"59.13\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_2\">\n",
       "      <!-- 5 -->\n",
       "      <g transform=\"translate(92.439375 73.728437)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-35\" d=\"M 691 4666 \n",
       "L 3169 4666 \n",
       "L 3169 4134 \n",
       "L 1269 4134 \n",
       "L 1269 2991 \n",
       "Q 1406 3038 1543 3061 \n",
       "Q 1681 3084 1819 3084 \n",
       "Q 2600 3084 3056 2656 \n",
       "Q 3513 2228 3513 1497 \n",
       "Q 3513 744 3044 326 \n",
       "Q 2575 -91 1722 -91 \n",
       "Q 1428 -91 1123 -41 \n",
       "Q 819 9 494 109 \n",
       "L 494 744 \n",
       "Q 775 591 1075 516 \n",
       "Q 1375 441 1709 441 \n",
       "Q 2250 441 2565 725 \n",
       "Q 2881 1009 2881 1497 \n",
       "Q 2881 1984 2565 2268 \n",
       "Q 2250 2553 1709 2553 \n",
       "Q 1456 2553 1204 2497 \n",
       "Q 953 2441 691 2322 \n",
       "L 691 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-35\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_3\">\n",
       "     <!-- Keys -->\n",
       "     <g transform=\"translate(78.371094 87.406562)scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-4b\" d=\"M 628 4666 \n",
       "L 1259 4666 \n",
       "L 1259 2694 \n",
       "L 3353 4666 \n",
       "L 4166 4666 \n",
       "L 1850 2491 \n",
       "L 4331 0 \n",
       "L 3500 0 \n",
       "L 1259 2247 \n",
       "L 1259 0 \n",
       "L 628 0 \n",
       "L 628 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-65\" d=\"M 3597 1894 \n",
       "L 3597 1613 \n",
       "L 953 1613 \n",
       "Q 991 1019 1311 708 \n",
       "Q 1631 397 2203 397 \n",
       "Q 2534 397 2845 478 \n",
       "Q 3156 559 3463 722 \n",
       "L 3463 178 \n",
       "Q 3153 47 2828 -22 \n",
       "Q 2503 -91 2169 -91 \n",
       "Q 1331 -91 842 396 \n",
       "Q 353 884 353 1716 \n",
       "Q 353 2575 817 3079 \n",
       "Q 1281 3584 2069 3584 \n",
       "Q 2775 3584 3186 3129 \n",
       "Q 3597 2675 3597 1894 \n",
       "z\n",
       "M 3022 2063 \n",
       "Q 3016 2534 2758 2815 \n",
       "Q 2500 3097 2075 3097 \n",
       "Q 1594 3097 1305 2825 \n",
       "Q 1016 2553 972 2059 \n",
       "L 3022 2063 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-79\" d=\"M 2059 -325 \n",
       "Q 1816 -950 1584 -1140 \n",
       "Q 1353 -1331 966 -1331 \n",
       "L 506 -1331 \n",
       "L 506 -850 \n",
       "L 844 -850 \n",
       "Q 1081 -850 1212 -737 \n",
       "Q 1344 -625 1503 -206 \n",
       "L 1606 56 \n",
       "L 191 3500 \n",
       "L 800 3500 \n",
       "L 1894 763 \n",
       "L 2988 3500 \n",
       "L 3597 3500 \n",
       "L 2059 -325 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
       "L 2834 2853 \n",
       "Q 2591 2978 2328 3040 \n",
       "Q 2066 3103 1784 3103 \n",
       "Q 1356 3103 1142 2972 \n",
       "Q 928 2841 928 2578 \n",
       "Q 928 2378 1081 2264 \n",
       "Q 1234 2150 1697 2047 \n",
       "L 1894 2003 \n",
       "Q 2506 1872 2764 1633 \n",
       "Q 3022 1394 3022 966 \n",
       "Q 3022 478 2636 193 \n",
       "Q 2250 -91 1575 -91 \n",
       "Q 1294 -91 989 -36 \n",
       "Q 684 19 347 128 \n",
       "L 347 722 \n",
       "Q 666 556 975 473 \n",
       "Q 1284 391 1588 391 \n",
       "Q 1994 391 2212 530 \n",
       "Q 2431 669 2431 922 \n",
       "Q 2431 1156 2273 1281 \n",
       "Q 2116 1406 1581 1522 \n",
       "L 1381 1569 \n",
       "Q 847 1681 609 1914 \n",
       "Q 372 2147 372 2553 \n",
       "Q 372 3047 722 3315 \n",
       "Q 1072 3584 1716 3584 \n",
       "Q 2034 3584 2315 3537 \n",
       "Q 2597 3491 2834 3397 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-4b\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-65\" x=\"60.576172\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-79\" x=\"122.099609\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-73\" x=\"181.279297\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_2\">\n",
       "    <g id=\"ytick_1\">\n",
       "     <g id=\"line2d_3\">\n",
       "      <defs>\n",
       "       <path id=\"mdf7287543b\" d=\"M 0 0 \n",
       "L -3.5 0 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#mdf7287543b\" x=\"34.240625\" y=\"42.39\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_4\">\n",
       "      <!-- 0 -->\n",
       "      <g transform=\"translate(20.878125 46.189219)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_2\">\n",
       "     <g id=\"line2d_4\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#mdf7287543b\" x=\"34.240625\" y=\"53.55\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_5\">\n",
       "      <!-- 1 -->\n",
       "      <g transform=\"translate(20.878125 57.349219)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
       "L 1825 531 \n",
       "L 1825 4091 \n",
       "L 703 3866 \n",
       "L 703 4441 \n",
       "L 1819 4666 \n",
       "L 2450 4666 \n",
       "L 2450 531 \n",
       "L 3481 531 \n",
       "L 3481 0 \n",
       "L 794 0 \n",
       "L 794 531 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_6\">\n",
       "     <!-- Queries -->\n",
       "     <g transform=\"translate(14.798437 67.277031)rotate(-90)scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-51\" d=\"M 2522 4238 \n",
       "Q 1834 4238 1429 3725 \n",
       "Q 1025 3213 1025 2328 \n",
       "Q 1025 1447 1429 934 \n",
       "Q 1834 422 2522 422 \n",
       "Q 3209 422 3611 934 \n",
       "Q 4013 1447 4013 2328 \n",
       "Q 4013 3213 3611 3725 \n",
       "Q 3209 4238 2522 4238 \n",
       "z\n",
       "M 3406 84 \n",
       "L 4238 -825 \n",
       "L 3475 -825 \n",
       "L 2784 -78 \n",
       "Q 2681 -84 2626 -87 \n",
       "Q 2572 -91 2522 -91 \n",
       "Q 1538 -91 948 567 \n",
       "Q 359 1225 359 2328 \n",
       "Q 359 3434 948 4092 \n",
       "Q 1538 4750 2522 4750 \n",
       "Q 3503 4750 4090 4092 \n",
       "Q 4678 3434 4678 2328 \n",
       "Q 4678 1516 4351 937 \n",
       "Q 4025 359 3406 84 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-75\" d=\"M 544 1381 \n",
       "L 544 3500 \n",
       "L 1119 3500 \n",
       "L 1119 1403 \n",
       "Q 1119 906 1312 657 \n",
       "Q 1506 409 1894 409 \n",
       "Q 2359 409 2629 706 \n",
       "Q 2900 1003 2900 1516 \n",
       "L 2900 3500 \n",
       "L 3475 3500 \n",
       "L 3475 0 \n",
       "L 2900 0 \n",
       "L 2900 538 \n",
       "Q 2691 219 2414 64 \n",
       "Q 2138 -91 1772 -91 \n",
       "Q 1169 -91 856 284 \n",
       "Q 544 659 544 1381 \n",
       "z\n",
       "M 1991 3584 \n",
       "L 1991 3584 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-72\" d=\"M 2631 2963 \n",
       "Q 2534 3019 2420 3045 \n",
       "Q 2306 3072 2169 3072 \n",
       "Q 1681 3072 1420 2755 \n",
       "Q 1159 2438 1159 1844 \n",
       "L 1159 0 \n",
       "L 581 0 \n",
       "L 581 3500 \n",
       "L 1159 3500 \n",
       "L 1159 2956 \n",
       "Q 1341 3275 1631 3429 \n",
       "Q 1922 3584 2338 3584 \n",
       "Q 2397 3584 2469 3576 \n",
       "Q 2541 3569 2628 3553 \n",
       "L 2631 2963 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-69\" d=\"M 603 3500 \n",
       "L 1178 3500 \n",
       "L 1178 0 \n",
       "L 603 0 \n",
       "L 603 3500 \n",
       "z\n",
       "M 603 4863 \n",
       "L 1178 4863 \n",
       "L 1178 4134 \n",
       "L 603 4134 \n",
       "L 603 4863 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-51\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-75\" x=\"78.710938\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-65\" x=\"142.089844\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-72\" x=\"203.613281\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-69\" x=\"244.726562\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-65\" x=\"272.509766\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-73\" x=\"334.033203\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"patch_3\">\n",
       "    <path d=\"M 34.240625 59.13 \n",
       "L 34.240625 36.81 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_4\">\n",
       "    <path d=\"M 145.840625 59.13 \n",
       "L 145.840625 36.81 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_5\">\n",
       "    <path d=\"M 34.240625 59.13 \n",
       "L 145.840625 59.13 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_6\">\n",
       "    <path d=\"M 34.240625 36.81 \n",
       "L 145.840625 36.81 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "  </g>\n",
       "  <g id=\"axes_2\">\n",
       "   <g id=\"patch_7\">\n",
       "    <path d=\"M 152.815625 88.74 \n",
       "L 156.892625 88.74 \n",
       "L 156.892625 7.2 \n",
       "L 152.815625 7.2 \n",
       "z\n",
       "\" style=\"fill: #ffffff\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_8\">\n",
       "    <path clip-path=\"url(#p673328f39b)\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.01; stroke-linejoin: miter\"/>\n",
       "   </g>\n",
       "   <image xlink:href=\"data:image/png;base64,\n",
       "iVBORw0KGgoAAAANSUhEUgAAAAQAAABSCAYAAABzJnWUAAAAm0lEQVR4nJ2SOw7CUBADHxL3vyoNBeyXFsZIFqQcje1NlMveb3venuuZPg6MMXbKdvxcqqASoDG7+TSG7xBjm5GSSLgVjXiDQO4Izvq39SsepHxC/g5/lLJDjA2C4mzKHQ5MjjF01q50obR7P0E1jYIhkZRSgmKk7UoNI8tZgiLA6hdDI0cM1yF3BE//YqA0xCB4SClB8FK5g6UvE4PMuTPJ8jwAAAAASUVORK5CYII=\" id=\"imagede191898c1\" transform=\"scale(1 -1)translate(0 -82)\" x=\"153\" y=\"-6\" width=\"4\" height=\"82\"/>\n",
       "   <g id=\"matplotlib.axis_3\">\n",
       "    <g id=\"ytick_3\">\n",
       "     <g id=\"line2d_5\">\n",
       "      <defs>\n",
       "       <path id=\"mde68a75cba\" d=\"M 0 0 \n",
       "L 3.5 0 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#mde68a75cba\" x=\"156.892625\" y=\"88.74\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_7\">\n",
       "      <!-- 0.0 -->\n",
       "      <g transform=\"translate(163.892625 92.539219)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-2e\" d=\"M 684 794 \n",
       "L 1344 794 \n",
       "L 1344 0 \n",
       "L 684 0 \n",
       "L 684 794 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_4\">\n",
       "     <g id=\"line2d_6\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#mde68a75cba\" x=\"156.892625\" y=\"56.124\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_8\">\n",
       "      <!-- 0.2 -->\n",
       "      <g transform=\"translate(163.892625 59.923219)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
       "L 3431 531 \n",
       "L 3431 0 \n",
       "L 469 0 \n",
       "L 469 531 \n",
       "Q 828 903 1448 1529 \n",
       "Q 2069 2156 2228 2338 \n",
       "Q 2531 2678 2651 2914 \n",
       "Q 2772 3150 2772 3378 \n",
       "Q 2772 3750 2511 3984 \n",
       "Q 2250 4219 1831 4219 \n",
       "Q 1534 4219 1204 4116 \n",
       "Q 875 4013 500 3803 \n",
       "L 500 4441 \n",
       "Q 881 4594 1212 4672 \n",
       "Q 1544 4750 1819 4750 \n",
       "Q 2544 4750 2975 4387 \n",
       "Q 3406 4025 3406 3419 \n",
       "Q 3406 3131 3298 2873 \n",
       "Q 3191 2616 2906 2266 \n",
       "Q 2828 2175 2409 1742 \n",
       "Q 1991 1309 1228 531 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-32\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_5\">\n",
       "     <g id=\"line2d_7\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#mde68a75cba\" x=\"156.892625\" y=\"23.508\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_9\">\n",
       "      <!-- 0.4 -->\n",
       "      <g transform=\"translate(163.892625 27.307219)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \n",
       "L 825 1625 \n",
       "L 2419 1625 \n",
       "L 2419 4116 \n",
       "z\n",
       "M 2253 4666 \n",
       "L 3047 4666 \n",
       "L 3047 1625 \n",
       "L 3713 1625 \n",
       "L 3713 1100 \n",
       "L 3047 1100 \n",
       "L 3047 0 \n",
       "L 2419 0 \n",
       "L 2419 1100 \n",
       "L 313 1100 \n",
       "L 313 1709 \n",
       "L 2253 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-34\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"LineCollection_1\"/>\n",
       "   <g id=\"patch_9\">\n",
       "    <path d=\"M 152.815625 88.74 \n",
       "L 154.854125 88.74 \n",
       "L 156.892625 88.74 \n",
       "L 156.892625 7.2 \n",
       "L 154.854125 7.2 \n",
       "L 152.815625 7.2 \n",
       "L 152.815625 88.74 \n",
       "z\n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "  </g>\n",
       " </g>\n",
       " <defs>\n",
       "  <clipPath id=\"pf0fd76f9ae\">\n",
       "   <rect x=\"34.240625\" y=\"36.81\" width=\"111.6\" height=\"22.32\"/>\n",
       "  </clipPath>\n",
       "  <clipPath id=\"p673328f39b\">\n",
       "   <rect x=\"152.815625\" y=\"7.2\" width=\"4.077\" height=\"81.54\"/>\n",
       "  </clipPath>\n",
       " </defs>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<Figure size 180x180 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),\n",
    "                  xlabel='Keys', ylabel='Queries')"
   ]
  },
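  {
   "cell_type": "markdown",
   "id": "8f2c6a5e",
   "metadata": {},
   "source": [
    "The uniform weights can also be reproduced without the `d2l` helpers.\n",
    "The sketch below assumes toy data consistent with the output printed above: `keys = torch.ones((2, 10, 2))`, `values` built by repeating `torch.arange(40)` reshaped to `(1, 10, 4)` along the batch axis, and `valid_lens = torch.tensor([2, 6])`.\n",
    "It also replaces `masked_softmax` with a simplified, hypothetical `simple_masked_softmax` that supports only one valid length per example.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import torch\n",
    "\n",
    "def simple_masked_softmax(scores, valid_lens):\n",
    "    # scores: (batch, num_queries, num_keys); valid_lens: (batch,)\n",
    "    positions = torch.arange(scores.shape[-1])\n",
    "    mask = positions[None, None, :] < valid_lens[:, None, None]\n",
    "    # Masked positions get a very negative score, so softmax gives them ~0 weight\n",
    "    return torch.softmax(scores.masked_fill(~mask, -1e6), dim=-1)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "queries = torch.normal(0, 1, (2, 1, 2))\n",
    "keys = torch.ones((2, 10, 2))  # every key is identical\n",
    "values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)\n",
    "valid_lens = torch.tensor([2, 6])\n",
    "\n",
    "d = queries.shape[-1]\n",
    "scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)\n",
    "weights = simple_masked_softmax(scores, valid_lens)\n",
    "print(weights)                     # 1/2 on the first 2 keys, 1/6 on the first 6 keys\n",
    "print(torch.bmm(weights, values))  # matches the DotProductAttention output above\n",
    "```"
   ]
  },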
  {
   "cell_type": "markdown",
   "id": "911f8588",
   "metadata": {
    "origin_pos": 44
   },
   "source": [
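    "## Summary\n",
    "\n",
    "* The output of attention pooling can be computed as a weighted average of values, and different choices of the attention scoring function lead to different attention pooling behaviors.\n",
    "* When queries and keys are vectors of different lengths, we can use the additive attention scoring function. When they have the same length, the scaled dot-product attention scoring function is more computationally efficient.\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Modify the keys in the toy example and visualize the attention weights. Do additive attention and scaled dot-product attention still produce the same results? Why or why not?\n",
    "1. Using matrix multiplications only, can you design a new scoring function for queries and keys with different vector lengths?\n",
    "1. When queries and keys have the same vector length, is vector summation a better design than the dot product for the scoring function? Why or why not?\n"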
   ]
  },
  {
   "cell_type": "markdown",
   "id": "625a83ae",
   "metadata": {
    "origin_pos": 46,
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "[Discussions](https://discuss.d2l.ai/t/5752)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "required_libs": []
 },
 "nbformat": 4,
 "nbformat_minor": 5
}